import collections

from LeetTool import TreeNode


class Solution:
    def correctBinaryTree(self, root: "TreeNode") -> "TreeNode":
        """Detach the single invalid node (and its subtree) from a binary tree.

        The tree is expected to contain exactly one defective node whose
        ``right`` pointer refers to another node on the same depth, located
        to its right. That defective node, together with everything hanging
        below it, is removed by clearing the pointer its parent holds to it.

        Strategy: a breadth-first traversal that enqueues ``right`` before
        ``left``, so each level is scanned from right to left. When a node is
        examined, every node to its right on the same level has already been
        dequeued and recorded. A ``right`` pointer that lands on an
        already-recorded node therefore identifies the defective node. A
        legitimate right child lives one level deeper and is not yet recorded.

        Args:
            root: Root of the tree; may be ``None``.

        Returns:
            The same ``root`` object. The tree is modified in place.
        """
        if root is None:
            return root

        # Track nodes by identity (id) so the search does not depend on how
        # the node class implements __eq__ / __hash__. Every id stays valid
        # for the whole call because all nodes remain referenced by the tree.
        visited = {id(root)}
        queue = collections.deque([(root, None)])
        while queue:
            node, parent = queue.popleft()
            target = node.right
            if target is not None and id(target) in visited:
                # `node` is the defective one. Under the problem's guarantee
                # it is never the root (the root is alone on its level), so
                # `parent` is set here.
                if parent.left is node:
                    parent.left = None
                else:
                    parent.right = None
                return root
            # Right before left: keeps every level ordered right-to-left.
            if target is not None:
                queue.append((target, node))
            if node.left is not None:
                queue.append((node.left, node))
            visited.add(id(node))
        return root


if __name__ == "__main__":
    # Placeholder entry point: running this file as a script does nothing.
    ...
